import numpy as np
import pandas as pd
import sys
import os
import anndata
from plotnine import *
import scanpy as sc
import itertools
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances


def gc(seqs, seq_col='Sequence'):
    """
    Calculate the GC fraction of a DNA sequence or collection of sequences.

    Args:
        seqs (str, list, tuple, pd.Series, np.ndarray, pd.DataFrame): A single
            DNA sequence, a collection of sequences, or a dataframe holding the
            sequences in column `seq_col`.
        seq_col (str, optional): Name of the column containing the sequences.
            Only used when `seqs` is a dataframe. Default is 'Sequence'.

    Returns:
        (list, float): The fraction of the sequence comprised of G and C bases.
        A float for a single string, otherwise a list with one value per
        sequence, in input order.

    Raises:
        TypeError: If `seqs` (or one of its elements) is not a supported type.
        ZeroDivisionError: If a sequence is the empty string.

    """
    if isinstance(seqs, str):
        # Counting is case-sensitive: only uppercase "G" and "C" are counted.
        return float(seqs.count("G") + seqs.count("C")) / len(seqs)

    if isinstance(seqs, pd.DataFrame):
        seqs = seqs[seq_col]

    if isinstance(seqs, (list, tuple, pd.Series, np.ndarray)):
        return [gc(seq) for seq in seqs]

    # Previously unsupported inputs (including plain lists) fell through and
    # silently returned None; fail loudly instead.
    raise TypeError(
        "seqs must be a str, a list-like of str or a pandas DataFrame, "
        f"got {type(seqs).__name__}"
    )


def kmer_frequencies(seqs, k, normalize=False):
    """
    Get frequencies of all kmers of length k in a sequence or list of sequences.

    K-mers are counted over every overlapping window of length k, so a
    sequence of length L contributes L - k + 1 windows. Only k-mers made of
    the uppercase bases A, C, G and T are tallied; windows containing any
    other character are ignored.

    Args:
        seqs (str, list): A DNA sequence or a list of DNA sequences.
        k (int): The length of the k-mer.
        normalize (bool, optional): Whether to divide the counts by the number
            of windows (L - k + 1), so that the values sum to 1 when every
            window is a valid k-mer. Default is False.

    Returns:
        (pd.DataFrame, anndata.AnnData): For a single string, a dataframe of
        shape (kmers x 1), indexed by k-mer, with the values in column 0.
        For a list, an AnnData object built from a (sequences x kmers) table.

    Raises:
        AssertionError: If k is larger than the length of a sequence.
        TypeError: If `seqs` is neither a str nor a list.
    """
    if isinstance(seqs, str):
        assert k <= len(
            seqs
        ), "k must be smaller than or equal to the length of the sequence"

        # Start every possible k-mer at zero so absent k-mers still get a row,
        # in the fixed itertools.product order.
        counts = dict.fromkeys(
            ("".join(kmer) for kmer in itertools.product("ACGT", repeat=k)), 0
        )

        # Slide a window over the sequence. str.count() would only count
        # non-overlapping matches (e.g. "AAAA".count("AA") == 2, not 3),
        # which disagrees with the L - k + 1 normalization below.
        n_windows = len(seqs) - k + 1
        for start in range(n_windows):
            window = seqs[start:start + k]
            if window in counts:
                counts[window] += 1

        # Make dataframe with kmers as rows
        output = pd.DataFrame.from_dict(counts, orient="index")

        # Normalize
        if normalize:
            output[0] = output[0] / n_windows

        return output

    if isinstance(seqs, list):
        # One column per sequence, then transpose to (sequences x kmers)
        output = pd.concat(
            [kmer_frequencies(seq, k, normalize) for seq in seqs], axis=1
        ).T.reset_index(drop=True)
        return anndata.AnnData(output)

    # Previously unsupported inputs fell through and silently returned None.
    raise TypeError(
        f"seqs must be a str or a list of str, got {type(seqs).__name__}"
    )


def select_length(df, seq_col='seq', target=80):
    """
    Keep only the rows whose sequence has exactly `target` characters.

    Args:
        df (pd.DataFrame): Dataframe containing the sequences.
        seq_col (str, optional): Name of the column holding the sequences.
            Default is 'seq'.
        target (int, optional): Required sequence length. Default is 80.

    Returns:
        (pd.DataFrame): The rows of `df` whose `seq_col` entry has length `target`.
    """
    seq_lengths = df[seq_col].apply(len)
    has_target_length = seq_lengths == target
    return df[has_target_length]


def drop_Ns(df, seq_col='seq'):
    """
    Remove the rows whose sequence contains the character 'N'.

    The check is case-sensitive: only an uppercase 'N' causes a row to be dropped.

    Args:
        df (pd.DataFrame): Dataframe containing the sequences.
        seq_col (str, optional): Name of the column holding the sequences.
            Default is 'seq'.

    Returns:
        (pd.DataFrame): The rows of `df` whose `seq_col` entry contains no 'N'.
    """
    def _has_n(seq):
        return 'N' in seq

    contains_n = df[seq_col].apply(_has_n)
    return df[~contains_n]


def one_nn(ad, pca=True, group_col='Group', label_col='label'):
    """
    Find the group of each point's nearest neighbor and summarize the result.

    For every observation, the closest other observation is located and its
    group is stored in `ad.obs['one_nn_group']` (this modifies `ad` in place).
    The neighbor groups are then tabulated per (label, group) combination.

    Args:
        ad: AnnData-like object providing `obs`, `X` and, when `pca` is True,
            `obsm['X_pca']`.
        pca (bool, optional): Whether to search neighbors in `ad.obsm['X_pca']`
            instead of `ad.X`. Default is True.
        group_col (str, optional): Column of `ad.obs` containing group IDs.
            Default is 'Group'.
        label_col (str, optional): Column of `ad.obs` containing labels.
            Default is 'label'.

    Returns:
        (pd.DataFrame): A dataframe indexed by (`label_col`, `group_col`) with one
        column per neighbor group, containing the proportion of points in that
        row whose nearest neighbor belongs to the column's group.
    """
    mat = ad.obsm['X_pca'] if pca else ad.X

    # Get neighbors: two hits per point, normally the point itself plus its
    # closest other point.
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(mat)
    _, indices = nbrs.kneighbors(mat, n_neighbors=2)

    # With exact duplicates the point itself is not guaranteed to come first,
    # so pick the first hit that is not the query point.
    self_is_first = indices[:, 0] == np.arange(indices.shape[0])
    nn_idx = np.where(self_is_first, indices[:, 1], indices[:, 0])

    # The neighbor indices are positions, so use .iloc; plain [] on a Series
    # is label-based and its positional fallback is deprecated in pandas.
    ad.obs['one_nn_group'] = ad.obs[group_col].iloc[nn_idx].tolist()

    # Calculate proportion. The count column is named explicitly because its
    # default name differs between pandas versions.
    counts = (
        ad.obs[[group_col, 'one_nn_group', label_col]]
        .value_counts()
        .reset_index(name='count')
    )
    df = counts.pivot_table(index=[label_col, group_col], columns='one_nn_group', values='count').fillna(0).astype(int)
    df = df.div(df.sum(axis=1), axis=0)
    return df


def ref_dist(ad, pca=True, group_col='Group', ref_group='Test Set'):
    """
    Compute each point's Euclidean distance to its closest reference point.

    The result is written to `ad.obs['ref_dist']` (this modifies `ad` in
    place). Points of the reference group are compared against the other
    reference points, i.e. a point's zero distance to itself is left out.

    Args:
        ad: AnnData-like object providing `obs`, `X` and, when `pca` is True,
            `obsm['X_pca']`.
        pca (bool, optional): Whether to measure distances in
            `ad.obsm['X_pca']` instead of `ad.X`. Default is True.
        group_col (str, optional): Column of `ad.obs` containing group IDs.
            Default is 'Group'.
        ref_group (str, optional): ID of the group to use as reference.
            Default is 'Test Set'.

    Returns:
        The input `ad`, with the 'ref_dist' column filled in.
    """
    groups = ad.obs[group_col].unique()

    # Pick the coordinate matrix once; every group is sliced from it.
    mat = ad.obsm['X_pca'] if pca else ad.X
    ref_X = mat[ad.obs[group_col]==ref_group, :]

    for group in groups:
        in_group = ad.obs[group_col]==group

        if group == ref_group:
            # ref to ref: drop entry i of row i (the point itself) before
            # taking the minimum
            self_dists = pairwise_distances(ref_X, ref_X, metric='euclidean')
            nearest = [np.delete(row, i).min() for i, row in enumerate(self_dists)]
        else:
            group_X = mat[in_group, :]
            nearest = pairwise_distances(group_X, ref_X, metric='euclidean').min(1)

        ad.obs.loc[in_group, 'ref_dist'] = nearest

    return ad


def within_group_knn_dist(ad, n_neighbors=10, group_col='Group', use_pca=False):
    """
    Compute each point's mean distance to its nearest neighbors within its own group.

    The result is written to `ad.obs['KNN Distance']` (this modifies `ad` in
    place). A point is never counted as its own neighbor, so the mean is taken
    over `n_neighbors` other points of the same group.

    Args:
        ad: AnnData-like object providing `obs`, `X` and, when `use_pca` is
            True, `obsm['X_pca']`.
        n_neighbors (int, optional): Number of neighbors to average over.
            Default is 10.
        group_col (str, optional): Column of `ad.obs` containing group IDs.
            Default is 'Group'.
        use_pca (bool, optional): Whether to measure distances in
            `ad.obsm['X_pca']` instead of `ad.X`. Default is False.

    Returns:
        The input `ad`, with the 'KNN Distance' column filled in.

    Raises:
        ValueError: Raised by scikit-learn when a group does not contain more
            than `n_neighbors` points.
    """
    # Pick the coordinate matrix once; every group is sliced from it.
    mat = ad.obsm['X_pca'] if use_pca else ad.X

    for group in ad.obs[group_col].unique():
        in_group = ad.obs[group_col]==group
        group_X = mat[in_group, :]
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(group_X)

        # Query the fitted points themselves (no X argument) so that each
        # point is excluded from its own neighbor list. Passing group_X
        # explicitly would return every point as its own nearest neighbor at
        # distance 0, biasing the mean low (and making it 0 for n_neighbors=1).
        distances, _ = nbrs.kneighbors()
        ad.obs.loc[in_group, 'KNN Distance'] = distances.mean(1)
    return ad


def min_edit_distance_from_reference(df, reference_group, group_col='Group', seq_col='Sequence'):
    """
    For each sequence in non-reference groups, find the smallest edit distance between that sequence
    and the sequences in the reference group

    Args:
        df (pd.DataFrame): Dataframe containing sequences in column `seq_col`
        reference_group (str): ID for the group to use as reference
        group_col (str): Name of the column containing group IDs
        seq_col (str): Name of the column containing the sequences. Default is 'Sequence'

    Returns:
        edit (np.array): Integer array with one entry per row of `df`, holding the edit distance
        between each sequence and its closest reference sequence. Set to 0 for reference sequences
    """
    # List nonreference groups
    nonreference_groups = [
        group for group in df[group_col].unique() if group != reference_group
    ]

    # Create empty array; reference rows keep this initial 0
    edit = np.zeros(len(df), dtype=int)

    # Get reference sequences
    reference_seqs = df[seq_col][df[group_col] == reference_group].tolist()

    # Calculate distances
    for group in nonreference_groups:
        in_group = df[group_col] == group
        group_seqs = df[seq_col][in_group].tolist()
        group_edit = min_edit_distance(group_seqs, reference_seqs)
        # Use a plain boolean array so the positional assignment does not
        # depend on the dataframe's index
        edit[in_group.to_numpy()] = group_edit

    return edit
